{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Download the IMDB Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Download reviews.txt and labels.txt from here: https://github.com/udacity/deep-learning/tree/master/sentiment-network\n",
    "\n",
    "def pretty_print_review_and_label(i):\n",
    "   print(labels[i] + \"\\t:\\t\" + reviews[i][:80] + \"...\")\n",
    "\n",
    "g = open('reviews.txt','r') # What we know!\n",
    "reviews = list(map(lambda x:x[:-1],g.readlines()))\n",
    "g.close()\n",
    "\n",
    "g = open('labels.txt','r') # What we WANT to know!\n",
    "labels = list(map(lambda x:x[:-1].upper(),g.readlines()))\n",
    "g.close()"
   ]
  },
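  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick sanity check (a small sketch, not part of the original pipeline): print a couple of `(label, review)` pairs with the helper defined above to confirm the files loaded correctly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Peek at the first two (label, review) pairs.\n",
    "pretty_print_review_and_label(0)\n",
    "pretty_print_review_and_label(1)"
   ]
  },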
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Capturing Word Correlation in Input Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sent Encoding:[1 1 0 1]\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "onehots = {}\n",
    "onehots['cat'] = np.array([1,0,0,0])\n",
    "onehots['the'] = np.array([0,1,0,0])\n",
    "onehots['dog'] = np.array([0,0,1,0])\n",
    "onehots['sat'] = np.array([0,0,0,1])\n",
    "\n",
    "sentence = ['the','cat','sat']\n",
    "x = onehots[sentence[0]] + \\\n",
    "    onehots[sentence[1]] + \\\n",
    "    onehots[sentence[2]]\n",
    "\n",
    "print(\"Sent Encoding:\" + str(x))"
   ]
  },
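  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Summing one-hot vectors yields a bag-of-words encoding: it records which words occur but discards their order. A minimal sketch, reusing the `onehots` dictionary above, shows that a reordered sentence produces the identical vector:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 'sat the cat' encodes to the same vector as 'the cat sat'.\n",
    "reordered = ['sat','the','cat']\n",
    "x2 = onehots[reordered[0]] + \\\n",
    "     onehots[reordered[1]] + \\\n",
    "     onehots[reordered[2]]\n",
    "print(\"Reordered Encoding:\" + str(x2))"
   ]
  },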
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Predicting Movie Reviews"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "f = open('reviews.txt')\n",
    "raw_reviews = f.readlines()\n",
    "f.close()\n",
    "\n",
    "f = open('labels.txt')\n",
    "raw_labels = f.readlines()\n",
    "f.close()\n",
    "\n",
    "tokens = list(map(lambda x:set(x.split(\" \")),raw_reviews))\n",
    "\n",
    "vocab = set()\n",
    "for sent in tokens:\n",
    "    for word in sent:\n",
    "        if(len(word)>0):\n",
    "            vocab.add(word)\n",
    "vocab = list(vocab)\n",
    "\n",
    "word2index = {}\n",
    "for i,word in enumerate(vocab):\n",
    "    word2index[word]=i\n",
    "\n",
    "input_dataset = list()\n",
    "for sent in tokens:\n",
    "    sent_indices = list()\n",
    "    for word in sent:\n",
    "        try:\n",
    "            sent_indices.append(word2index[word])\n",
    "        except:\n",
    "            \"\"\n",
    "    input_dataset.append(list(set(sent_indices)))\n",
    "\n",
    "target_dataset = list()\n",
    "for label in raw_labels:\n",
    "    if label == 'positive\\n':\n",
    "        target_dataset.append(1)\n",
    "    else:\n",
    "        target_dataset.append(0)"
   ]
  },
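  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before training, it's worth checking what the preprocessing produced (a quick sketch; exact counts depend on your copy of the dataset): one list of word indices per review and one 0/1 target per label."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity-check the preprocessed data.\n",
    "print(\"reviews:\", len(input_dataset))\n",
    "print(\"labels:\", len(target_dataset))\n",
    "print(\"vocab size:\", len(vocab))\n",
    "print(\"first review, first 10 word indices:\", input_dataset[0][:10])"
   ]
  },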
  {
   "cell_type": "raw",
   "metadata": {},
   "source": [
    "import numpy as np\n",
    "np.random.seed(1)\n",
    "\n",
    "def sigmoid(x):\n",
    "    return 1/(1 + np.exp(-x))\n",
    "\n",
    "alpha, iterations = (0.01, 2)\n",
    "hidden_size = 100\n",
    "\n",
    "weights_0_1 = 0.2*np.random.random((len(vocab),hidden_size)) - 0.1\n",
    "weights_1_2 = 0.2*np.random.random((hidden_size,1)) - 0.1\n",
    "\n",
    "correct,total = (0,0)\n",
    "for iter in range(iterations):\n",
    "    \n",
    "    # train on first 24,000\n",
    "    for i in range(len(input_dataset)-1000):\n",
    "\n",
    "        x,y = (input_dataset[i],target_dataset[i])\n",
    "        layer_1 = sigmoid(np.sum(weights_0_1[x],axis=0)) #embed + sigmoid\n",
    "        layer_2 = sigmoid(np.dot(layer_1,weights_1_2)) # linear + softmax\n",
    "\n",
    "        layer_2_delta = layer_2 - y # compare pred with truth\n",
    "        layer_1_delta = layer_2_delta.dot(weights_1_2.T) #backprop\n",
    "\n",
    "        weights_0_1[x] -= layer_1_delta * alpha\n",
    "        weights_1_2 -= np.outer(layer_1,layer_2_delta) * alpha\n",
    "\n",
    "        if(np.abs(layer_2_delta) < 0.5):\n",
    "            correct += 1\n",
    "        total += 1\n",
    "        if(i % 10 == 9):\n",
    "            progress = str(i/float(len(input_dataset)))\n",
    "            sys.stdout.write('\\rIter:'+str(iter)\\\n",
    "                             +' Progress:'+progress[2:4]\\\n",
    "                             +'.'+progress[4:6]\\\n",
    "                             +'% Training Accuracy:'\\\n",
    "                             + str(correct/float(total)) + '%')\n",
    "    print()\n",
    "correct,total = (0,0)\n",
    "for i in range(len(input_dataset)-1000,len(input_dataset)):\n",
    "\n",
    "    x = input_dataset[i]\n",
    "    y = target_dataset[i]\n",
    "\n",
    "    layer_1 = sigmoid(np.sum(weights_0_1[x],axis=0))\n",
    "    layer_2 = sigmoid(np.dot(layer_1,weights_1_2))\n",
    "    \n",
    "    if(np.abs(layer_2 - y) < 0.5):\n",
    "        correct += 1\n",
    "    total += 1\n",
    "print(\"Test Accuracy:\" + str(correct / float(total)))"
   ]
  },
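  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Once trained, the network can score a single review directly. A minimal sketch (assuming `weights_0_1`, `weights_1_2`, and `sigmoid` from the cell above) prints the predicted probability that the last review in the corpus is positive:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score one review: embed (sum the word rows), squash, then apply the output layer.\n",
    "x = input_dataset[-1]\n",
    "layer_1 = sigmoid(np.sum(weights_0_1[x],axis=0))\n",
    "layer_2 = sigmoid(np.dot(layer_1,weights_1_2))\n",
    "print(\"P(positive):\", layer_2[0], \"| true label:\", target_dataset[-1])"
   ]
  },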
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'',\n",
       " '\\n',\n",
       " '.',\n",
       " 'a',\n",
       " 'about',\n",
       " 'adults',\n",
       " 'age',\n",
       " 'all',\n",
       " 'and',\n",
       " 'as',\n",
       " 'at',\n",
       " 'believe',\n",
       " 'bromwell',\n",
       " 'burn',\n",
       " 'can',\n",
       " 'cartoon',\n",
       " 'classic',\n",
       " 'closer',\n",
       " 'comedy',\n",
       " 'down',\n",
       " 'episode',\n",
       " 'expect',\n",
       " 'far',\n",
       " 'fetched',\n",
       " 'financially',\n",
       " 'here',\n",
       " 'high',\n",
       " 'i',\n",
       " 'immediately',\n",
       " 'in',\n",
       " 'insightful',\n",
       " 'inspector',\n",
       " 'is',\n",
       " 'isn',\n",
       " 'it',\n",
       " 'knew',\n",
       " 'lead',\n",
       " 'life',\n",
       " 'line',\n",
       " 'm',\n",
       " 'many',\n",
       " 'me',\n",
       " 'much',\n",
       " 'my',\n",
       " 'of',\n",
       " 'one',\n",
       " 'other',\n",
       " 'pathetic',\n",
       " 'pettiness',\n",
       " 'pity',\n",
       " 'pomp',\n",
       " 'profession',\n",
       " 'programs',\n",
       " 'ran',\n",
       " 'reality',\n",
       " 'recalled',\n",
       " 'remind',\n",
       " 'repeatedly',\n",
       " 'right',\n",
       " 's',\n",
       " 'sack',\n",
       " 'same',\n",
       " 'satire',\n",
       " 'saw',\n",
       " 'school',\n",
       " 'schools',\n",
       " 'scramble',\n",
       " 'see',\n",
       " 'situation',\n",
       " 'some',\n",
       " 'student',\n",
       " 'students',\n",
       " 'such',\n",
       " 'survive',\n",
       " 't',\n",
       " 'teachers',\n",
       " 'teaching',\n",
       " 'than',\n",
       " 'that',\n",
       " 'the',\n",
       " 'their',\n",
       " 'think',\n",
       " 'through',\n",
       " 'time',\n",
       " 'to',\n",
       " 'tried',\n",
       " 'welcome',\n",
       " 'what',\n",
       " 'when',\n",
       " 'which',\n",
       " 'who',\n",
       " 'whole',\n",
       " 'years',\n",
       " 'your'}"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokens[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Comparing Word Embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import Counter\n",
    "import math \n",
    "\n",
    "def similar(target='beautiful'):\n",
    "    target_index = word2index[target]\n",
    "    scores = Counter()\n",
    "    for word,index in word2index.items():\n",
    "        raw_difference = weights_0_1[index] - (weights_0_1[target_index])\n",
    "        squared_difference = raw_difference * raw_difference\n",
    "        scores[word] = -math.sqrt(sum(squared_difference))\n",
    "\n",
    "    return scores.most_common(10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[('beautiful', -0.0), ('heart', -0.7461901055360456), ('captures', -0.7767713774499612), ('impact', -0.7851006592549541), ('unexpected', -0.8024296074764704), ('bit', -0.8041029062033365), ('touching', -0.8041105203290175), ('true', -0.8092335336931215), ('worth', -0.8095649927927353), ('strong', -0.8095814455120289)]\n"
     ]
    }
   ],
   "source": [
    "print(similar('beautiful'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[('terrible', -0.0), ('boring', -0.7591663900380615), ('lame', -0.7732283645546325), ('horrible', -0.788081854105546), ('disappointing', -0.7893120726668719), ('avoid', -0.7939105009456955), ('badly', -0.8054784389155504), ('annoying', -0.8067172753479477), ('dull', -0.8072650189634973), ('mess', -0.8139036459320503)]\n"
     ]
    }
   ],
   "source": [
    "print(similar('terrible'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Filling in the Blank"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys,random,math\n",
    "from collections import Counter\n",
    "import numpy as np\n",
    "\n",
    "np.random.seed(1)\n",
    "random.seed(1)\n",
    "f = open('reviews.txt')\n",
    "raw_reviews = f.readlines()\n",
    "f.close()\n",
    "\n",
    "tokens = list(map(lambda x:(x.split(\" \")),raw_reviews))\n",
    "wordcnt = Counter()\n",
    "for sent in tokens:\n",
    "    for word in sent:\n",
    "        wordcnt[word] -= 1\n",
    "vocab = list(set(map(lambda x:x[0],wordcnt.most_common())))\n",
    "\n",
    "word2index = {}\n",
    "for i,word in enumerate(vocab):\n",
    "    word2index[word]=i\n",
    "\n",
    "concatenated = list()\n",
    "input_dataset = list()\n",
    "for sent in tokens:\n",
    "    sent_indices = list()\n",
    "    for word in sent:\n",
    "        try:\n",
    "            sent_indices.append(word2index[word])\n",
    "            concatenated.append(word2index[word])\n",
    "        except:\n",
    "            \"\"\n",
    "    input_dataset.append(sent_indices)\n",
    "concatenated = np.array(concatenated)\n",
    "random.shuffle(input_dataset)"
   ]
  },
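  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Again, a quick look at what the preprocessing produced (a small sketch; exact counts depend on the dataset): `concatenated` is a flat stream of word indices used to draw negative samples, and `input_dataset` holds one index list per review."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity-check the CBOW inputs.\n",
    "print(\"vocab size:\", len(vocab))\n",
    "print(\"total tokens in concatenated:\", len(concatenated))\n",
    "print(\"reviews:\", len(input_dataset))"
   ]
  },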
  {
   "cell_type": "code",
   "execution_count": 69,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Progress:0.99998[('terrible', -0.0), ('horrible', -3.488841411481131), ('bad', -4.0636425093941595), ('brilliant', -4.211247495138625), ('pathetic', -4.304645745396163), ('fantastic', -4.341998952418319), ('fabulous', -4.356925869405997), ('phenomenal', -4.361301237074382), ('marvelous', -4.3856957968039145), ('spectacular', -4.413156799233535)]\n"
     ]
    }
   ],
   "source": [
    "alpha, iterations = (0.05, 2)\n",
    "hidden_size,window,negative = (50,2,5)\n",
    "\n",
    "weights_0_1 = (np.random.rand(len(vocab),hidden_size) - 0.5) * 0.2\n",
    "weights_1_2 = np.random.rand(len(vocab),hidden_size)*0\n",
    "\n",
    "layer_2_target = np.zeros(negative+1)\n",
    "layer_2_target[0] = 1\n",
    "\n",
    "def similar(target='beautiful'):\n",
    "  target_index = word2index[target]\n",
    "\n",
    "  scores = Counter()\n",
    "  for word,index in word2index.items():\n",
    "    raw_difference = weights_0_1[index] - (weights_0_1[target_index])\n",
    "    squared_difference = raw_difference * raw_difference\n",
    "    scores[word] = -math.sqrt(sum(squared_difference))\n",
    "  return scores.most_common(10)\n",
    "\n",
    "def sigmoid(x):\n",
    "    return 1/(1 + np.exp(-x))\n",
    "\n",
    "for rev_i,review in enumerate(input_dataset * iterations):\n",
    "  for target_i in range(len(review)):\n",
    "        \n",
    "    # since it's really expensive to predict every vocabulary\n",
    "    # we're only going to predict a random subset\n",
    "    target_samples = [review[target_i]]+list(concatenated\\\n",
    "    [(np.random.rand(negative)*len(concatenated)).astype('int').tolist()])\n",
    "\n",
    "    left_context = review[max(0,target_i-window):target_i]\n",
    "    right_context = review[target_i+1:min(len(review),target_i+window)]\n",
    "\n",
    "    layer_1 = np.mean(weights_0_1[left_context+right_context],axis=0)\n",
    "    layer_2 = sigmoid(layer_1.dot(weights_1_2[target_samples].T))\n",
    "    layer_2_delta = layer_2 - layer_2_target\n",
    "    layer_1_delta = layer_2_delta.dot(weights_1_2[target_samples])\n",
    "\n",
    "    weights_0_1[left_context+right_context] -= layer_1_delta * alpha\n",
    "    weights_1_2[target_samples] -= np.outer(layer_2_delta,layer_1)*alpha\n",
    "\n",
    "  if(rev_i % 250 == 0):\n",
    "    sys.stdout.write('\\rProgress:'+str(rev_i/float(len(input_dataset)\n",
    "        *iterations)) + \"   \" + str(similar('terrible')))\n",
    "  sys.stdout.write('\\rProgress:'+str(rev_i/float(len(input_dataset)\n",
    "        *iterations)))\n",
    "print(similar('terrible'))"
   ]
  },
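  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For comparison with the embeddings trained on sentiment labels earlier, the same neighborhood query against the fill-in-the-blank embeddings (a quick check; exact neighbors vary with training randomness):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(similar('beautiful'))"
   ]
  },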
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# King - Man + Woman ~= Queen"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "metadata": {},
   "outputs": [],
   "source": [
    "def analogy(positive=['terrible','good'],negative=['bad']):\n",
    "    \n",
    "    norms = np.sum(weights_0_1 * weights_0_1,axis=1)\n",
    "    norms.resize(norms.shape[0],1)\n",
    "    \n",
    "    normed_weights = weights_0_1 * norms\n",
    "    \n",
    "    query_vect = np.zeros(len(weights_0_1[0]))\n",
    "    for word in positive:\n",
    "        query_vect += normed_weights[word2index[word]]\n",
    "    for word in negative:\n",
    "        query_vect -= normed_weights[word2index[word]]\n",
    "    \n",
    "    scores = Counter()\n",
    "    for word,index in word2index.items():\n",
    "        raw_difference = weights_0_1[index] - query_vect\n",
    "        squared_difference = raw_difference * raw_difference\n",
    "        scores[word] = -math.sqrt(sum(squared_difference))\n",
    "        \n",
    "    return scores.most_common(10)[1:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('terrific', -210.46593317724228),\n",
       " ('perfect', -210.52652806032205),\n",
       " ('worth', -210.53162266358495),\n",
       " ('good', -210.55072184482773),\n",
       " ('terrible', -210.58429046605724),\n",
       " ('decent', -210.87945442008805),\n",
       " ('superb', -211.01143515971094),\n",
       " ('great', -211.1327058081335),\n",
       " ('worthy', -211.13577238103477)]"
      ]
     },
     "execution_count": 71,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "analogy(['terrible','good'],['bad'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('simon', -193.82490698964878),\n",
       " ('obsessed', -193.91805919583555),\n",
       " ('stanwyck', -194.22311983847902),\n",
       " ('sandler', -194.22846640800597),\n",
       " ('branagh', -194.24551334589853),\n",
       " ('daniel', -194.24631020485714),\n",
       " ('peter', -194.29908544092078),\n",
       " ('tony', -194.31388897167716),\n",
       " ('aged', -194.35115773165094)]"
      ]
     },
     "execution_count": 72,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "analogy(['elizabeth','he'],['she'])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
